#!/usr/bin/env python

import math
import copy
import rospy
import threading
import numpy as np
import modern_robotics as mr
from interbotix_ux_modules import mr_descriptions as mrd
from interbotix_uxarm_joy.msg import ArmJoy
from interbotix_common_modules import angle_manipulation as ang
from interbotix_ux_modules.gripper import InterbotixGripperUX

### @brief Class to control a Universal Factory Xarm robot with a joystick controller
class ArmRobot(object):
    ### @brief Reads the node's private parameters, connects to the arm (and optional gripper), and starts the joystick subscriber and the control-loop timer
    ### @details - the robot module is created with init_node=False, so the ROS node is expected to be initialized by the caller before constructing this class
    def __init__(self):
        self.rotate_step = 0.0175                                           # Step size [rad] for each iteration of the control loop when changing orientation [roll, pitch, yaw]
        self.translate_step = 0.0055                                        # Step size [m] for each iteration of the control loop when changing position [x, y, z]
        self.arm_inc = 10                                                   # Current step factor (between 0 and 20) used to change the step sizes above to be larger or smaller (changing the speed)
        self.arm_incs = {"coarse": 10, "fine": 10}                          # Holds the step factors for 'coarse' and 'fine' control
        self.gripper_inc = 25                                               # Step size [pulse] for each iteration of the control loop when changing gripper position
        self.gripper_command = 850                                          # Current gripper pulse (ranges between 0 and 850)
        self.current_mode_status = "servo"                                  # Current Xarm mode ("servo" or "teach")
        self.initial_pose_cmd = True                                        # True if this is the first iteration when executing a pose_cmd; False otherwise
        self.joint_positions = []                                           # Keeps track of joint position commands when executing a pose_cmd
        self.joy_msg = ArmJoy()                                             # Custom ROS message to receive processed Joystick commands
        self.joy_mutex = threading.Lock()                                   # Mutex to prevent reading/writing to the self.joy_msg simultaneously
        robot_model = rospy.get_param("~robot_model")
        robot_name = rospy.get_namespace().strip("/")
        ee_offset = rospy.get_param("~ee_offset")
        self.use_gripper = rospy.get_param("~use_gripper")
        gripper = "gripper" if self.use_gripper else None
        self.armbot = InterbotixGripperUX(robot_model, robot_name=robot_name, mode=1, wait_for_finish=True, ee_offset=ee_offset, pulse=850, pulse_vel=5000, gripper_type=gripper, init_node=False)
        self.T_sy = np.identity(4)                                                              # Identical frame pose as the robot's base_link frame - but with the yaw of the end-effector
        self.T_yb = np.identity(4)                                                              # Identical to self.T_sb but with 0 yaw
        self.T_sb = np.identity(4)                                                              # Transform of the robot's end-effector w.r.t. the base_link frame
        self.T_sy_inc_cw = ang.poseToTransformationMatrix([0, 0, 0, 0, 0, -self.rotate_step])   # Matrix used to rotate self.T_sy clockwise by self.rotate_step radians
        self.T_sy_inc_ccw = ang.poseToTransformationMatrix([0, 0, 0, 0, 0, self.rotate_step])   # Matrix used to rotate self.T_sy counterclockwise by self.rotate_step radians
        self.home_positions = [0] * self.armbot.ux.dof                                          # Predefined 'Home' pose as specified in MoveIt
        self.hold_up_positions = [0] * self.armbot.ux.dof                                       # Predefined 'Hold-Up' pose as specified in MoveIt
        self.hold_up_positions[-2] = -math.pi/2                                                 # 'Hold-Up' only differs from 'Home' in the second-to-last joint
        self.robot_des = getattr(mrd, robot_model)
        if ee_offset is not None:                                                               # If setting an ee_offset, adjust the M matrix
            T_bf = ang.poseToTransformationMatrix(ee_offset)
            # NOTE(review): this writes M back onto the object returned by getattr(mrd, robot_model) - confirm nothing else applies the same offset to it
            self.robot_des.M = np.dot(self.robot_des.M, T_bf)
        self.capture_joint_positions()                                                          # Get current joint states and determine self.T_sb
        rospy.Subscriber("commands/joy_processed", ArmJoy, self.joy_control_cb)
        rospy.Timer(rospy.Duration(1/45.0), self.arm_controller)                                # Run the control loop every 1/45 s

    ### @brief Helper function that updates the arm joint speed by increasing step size
    ### @param arm_inc - step factor multiplier (from 0 - 20)
    ### @details - also rewrites the upper-left 2x2 block of both waist increment matrices so they match the new self.rotate_step
    def update_arm_speed(self, arm_inc):
        self.arm_inc = arm_inc
        self.rotate_step = 0.005 + 0.00125 * self.arm_inc
        self.translate_step = 0.001 + 0.00045 * self.arm_inc
        self.T_sy_inc_cw[:2, :2] = ang.yawToRotationMatrix(-self.rotate_step)
        self.T_sy_inc_ccw[:2, :2] = ang.yawToRotationMatrix(self.rotate_step)
        rospy.loginfo("Current arm inc is %d." % self.arm_inc)

    ### @brief Helper function that updates the gripper speed by increasing step size
    ### @param gripper_inc - step increment (from 15 - 35)
    def update_gripper_speed(self, gripper_inc):
        self.gripper_inc = gripper_inc
        rospy.loginfo("Current gripper inc is %d." % self.gripper_inc)

    ### @brief Helper function that calculates the pose of the end-effector w.r.t. T_y
    ### @details - copies the yaw of self.T_sb into the upper-left 2x2 block of self.T_sy, then sets self.T_yb = transInv(self.T_sy) * self.T_sb
    def update_T_yb(self):
        rpy = ang.rotationMatrixToEulerAngles(self.T_sb[:3, :3])
        self.T_sy[:2,:2] = ang.yawToRotationMatrix(rpy[2])
        self.T_yb = np.dot(ang.transInv(self.T_sy), self.T_sb)

    ### @brief Helper function that moves the robot arm to the desired goal positions while in Mode 1
    ### @param goal_positions - desired joint goal positions [rad]
    ### @details - each call advances every joint in self.joint_positions by at most one step towards its goal
    ###            (0.03 for the second-to-last joint, 0.01 for all others) and sends the result to the arm;
    ###            it is meant to be called once per control-loop iteration until the goal is reached
    def move_to_goal_positions(self, goal_positions):
        dof = self.armbot.ux.dof
        for joint in range(dof):
            diff = self.joint_positions[joint] - goal_positions[joint]
            sign = -1.0 if diff > 0 else 1.0
            step = 0.01 if joint != dof-2 else 0.03
            if (abs(diff) > step):
                self.joint_positions[joint] += step * sign
            else:
                # Within one step of the goal, so snap to it to avoid oscillating around the target
                self.joint_positions[joint] = goal_positions[joint]
        self.armbot.ux.robot_move_servoj(self.joint_positions)

    ### @brief ROS Callback function that receives ArmJoy messages
    ### @param msg - ArmJoy ROS message
    ### @details - stores a copy of the message for the control loop and immediately handles the speed, speed-toggle, gripper-speed and mode commands
    def joy_control_cb(self, msg):
        with self.joy_mutex:
            self.joy_msg = copy.deepcopy(msg)

        # Check the speed_cmd
        if (msg.speed_cmd == ArmJoy.SPEED_INC and self.arm_inc < 20):
            self.update_arm_speed(self.arm_inc + 1)
        elif (msg.speed_cmd == ArmJoy.SPEED_DEC and self.arm_inc > 0):
            self.update_arm_speed(self.arm_inc - 1)

        # Check the speed_toggle_cmd (remember the factor of the setting being left, then restore the factor of the setting being entered)
        if (msg.speed_toggle_cmd == ArmJoy.SPEED_COARSE):
            self.arm_incs["fine"] = self.arm_inc
            rospy.loginfo("Switched to Coarse Control")
            self.update_arm_speed(self.arm_incs["coarse"])
        elif (msg.speed_toggle_cmd == ArmJoy.SPEED_FINE):
            self.arm_incs["coarse"] = self.arm_inc
            rospy.loginfo("Switched to Fine Control")
            self.update_arm_speed(self.arm_incs["fine"])

        # Check the gripper_speed_cmd
        if (msg.gripper_speed_cmd == ArmJoy.GRIPPER_SPEED_INC and self.gripper_inc < 35):
            self.update_gripper_speed(self.gripper_inc + 1)
        elif (msg.gripper_speed_cmd == ArmJoy.GRIPPER_SPEED_DEC and self.gripper_inc > 15):
            self.update_gripper_speed(self.gripper_inc - 1)

        # Check the mode_cmd (mode 1 is used for "servo" and mode 2 for "teach")
        if (msg.mode_cmd == ArmJoy.MODE_SERVO):
            self.armbot.ux.robot_smart_mode_reset(1)
            # Re-read the joint states since the arm may have been moved by hand in the meantime
            self.capture_joint_positions()
            self.current_mode_status = "servo"
        elif (msg.mode_cmd == ArmJoy.MODE_TEACH):
            self.armbot.ux.robot_smart_mode_reset(2)
            self.current_mode_status = "teach"

    ### @brief Sets the desired end-effector pose to the arm w.r.t. the Space frame
    ### @param T_sd - desired 4x4 transformation matrix
    ### @return <bool> - whether or not the desired pose was able to be achieved
    ### @details - self.T_sb is only updated when the command returns 0; any other return value resets the arm to mode 1
    def set_xarm_ee_pose_matrix(self, T_sd):
        # Translation is scaled by 1000 before being sent (presumably meters to millimeters - confirm against the driver)
        x = 1000 * T_sd[0,3]
        y = 1000 * T_sd[1,3]
        z = 1000 * T_sd[2,3]
        rpy = ang.rotationMatrixToEulerAngles(T_sd[:3,:3])
        cmd = [x, y, z, rpy[0], rpy[1], rpy[2]]
        ret = self.armbot.ux.robot_move_servo_cart(cmd)
        if ret == 0:
            self.T_sb = T_sd
            return True
        else:
            self.armbot.ux.robot_smart_mode_reset(1)
            return False

    ### @brief ROS Timer Callback function to operate the control loop
    ### @param event - TimerEvent message [unused]
    ### @details - processes, in order: the gripper command, the predefined pose command (which ends the iteration),
    ###            the waist command, and finally the end-effector position/orientation commands
    def arm_controller(self, event):

        ### if in 'teach' mode, arm is in zero gravity, so return...
        if (self.current_mode_status == "teach"): return

        # Work on a private copy so the subscriber callback can keep replacing self.joy_msg
        with self.joy_mutex:
            msg = copy.deepcopy(self.joy_msg)

        # Check the gripper_cmd (truthiness test matches how gripper_type is chosen in the constructor)
        if self.use_gripper:
            if (msg.gripper_cmd == ArmJoy.GRIPPER_OPEN and self.gripper_command < 850):
                self.gripper_command = min(self.gripper_command + self.gripper_inc, 850)
                self.armbot.gripper.move(self.gripper_command, delay=0)
            elif (msg.gripper_cmd == ArmJoy.GRIPPER_CLOSE and self.gripper_command > 0):
                self.gripper_command = max(self.gripper_command - self.gripper_inc, 0)
                self.armbot.gripper.move(self.gripper_command, delay=0)

        # Check the pose_cmd
        if (msg.pose_cmd != 0):
            if self.initial_pose_cmd:
                # First iteration of a pose command: pause briefly, then start from the joint states currently reported
                rospy.sleep(0.2)
                self.capture_joint_positions()
                self.initial_pose_cmd = False
            if (msg.pose_cmd == ArmJoy.HOME_POSE):
                self.move_to_goal_positions(self.home_positions)
            elif (msg.pose_cmd == ArmJoy.HOLDUP_POSE):
                self.move_to_goal_positions(self.hold_up_positions)
            return

        # A pose command just ended: recompute the end-effector pose from the last commanded joint positions
        if not self.initial_pose_cmd:
            self.initial_pose_cmd = True
            self.T_sb = mr.FKinSpace(self.robot_des.M, self.robot_des.Slist, self.joint_positions)
            self.update_T_yb()

        # Check the waist_cmd
        if (msg.waist_cmd != 0):
            T_sy_inc = None
            if (msg.waist_cmd == ArmJoy.WAIST_CCW):
                T_sy_inc = self.T_sy_inc_ccw
            elif (msg.waist_cmd == ArmJoy.WAIST_CW):
                T_sy_inc = self.T_sy_inc_cw
            # Ignore unrecognized non-zero values instead of failing on an unbound T_sy
            if T_sy_inc is not None:
                T_sy = np.dot(self.T_sy, T_sy_inc)
                T_sd = np.dot(T_sy, self.T_yb)
                # Only commit the rotated frame if the arm accepted the new pose
                if self.set_xarm_ee_pose_matrix(T_sd):
                    self.T_sy = np.array(T_sy)

        position_changed = msg.ee_x_cmd + msg.ee_y_cmd + msg.ee_z_cmd
        orientation_changed = msg.ee_yaw_cmd + msg.ee_pitch_cmd

        if (position_changed + orientation_changed == 0): return

        # Copy the most recent T_yb transform into a temporary variable
        T_yb = np.array(self.T_yb)

        if (position_changed):
            # check ee_x_cmd
            if (msg.ee_x_cmd == ArmJoy.EE_X_INC):
                T_yb[0, 3] += self.translate_step
            elif (msg.ee_x_cmd == ArmJoy.EE_X_DEC):
                T_yb[0, 3] -= self.translate_step

            # check ee_y_cmd
            if (msg.ee_y_cmd == ArmJoy.EE_Y_INC):
                T_yb[1, 3] += self.translate_step
            elif (msg.ee_y_cmd == ArmJoy.EE_Y_DEC):
                T_yb[1, 3] -= self.translate_step

            # check ee_z_cmd
            if (msg.ee_z_cmd == ArmJoy.EE_Z_INC):
                T_yb[2, 3] += self.translate_step
            elif (msg.ee_z_cmd == ArmJoy.EE_Z_DEC):
                T_yb[2, 3] -= self.translate_step

        # check end-effector orientation related commands
        if (orientation_changed != 0):
            rpy = ang.rotationMatrixToEulerAngles(T_yb[:3, :3])

            # check ee_pitch_cmd
            if (msg.ee_pitch_cmd == ArmJoy.EE_PITCH_DOWN):
                rpy[1] += self.rotate_step
            elif (msg.ee_pitch_cmd == ArmJoy.EE_PITCH_UP):
                rpy[1] -= self.rotate_step

            # check ee_yaw_cmd
            if (msg.ee_yaw_cmd == ArmJoy.EE_YAW_CCW):
                rpy[2] += self.rotate_step
            elif (msg.ee_yaw_cmd == ArmJoy.EE_YAW_CW):
                rpy[2] -= self.rotate_step

            T_yb[:3,:3] = ang.eulerAnglesToRotationMatrix(rpy)

        # Get desired transformation matrix of the end-effector w.r.t. the base frame
        T_sd = np.dot(self.T_sy, T_yb)
        # Only commit the modified T_yb if the arm accepted the new pose
        if self.set_xarm_ee_pose_matrix(T_sd):
            self.T_yb = np.array(T_yb)

    ### @brief Reads the current joint states and updates the current end-effector pose with them
    ### @details - should be used whenever joints are torqued off, right after torquing them on again;
    ###            the new list is built from a single joint_states snapshot and assigned in one statement
    ###            so other callbacks never observe a partially filled self.joint_positions
    def capture_joint_positions(self):
        ux = self.armbot.ux
        positions = ux.joint_states.position
        self.joint_positions = [positions[ux.js_index_map[name]] for name in ux.joint_names]
        self.T_sb = mr.FKinSpace(self.robot_des.M, self.robot_des.Slist, self.joint_positions)
        self.update_T_yb()

### @brief Entry point - initializes the 'uxarm_robot' ROS node, creates the joystick arm controller, and blocks until shutdown
def main():
    rospy.init_node('uxarm_robot')
    # The node must be initialized before the controller is constructed (the constructor reads ROS parameters)
    arm_robot = ArmRobot()
    # Hand control to ROS; the subscriber and timer callbacks do all the work from here on
    rospy.spin()

# Only start the node when this file is executed as a script (not when it is imported)
if __name__=='__main__':
    main()
